// standard.cpp
#define _WIN32_DCOM
#include <iostream.h>
#include "Component\component.h"
#include "registry.h"

// Registry entries written by DllRegisterServer (via RegisterServerEx) and removed
// by DllUnregisterServer (via UnregisterServerEx).
// Each row is { key, value name (0 = the key's default value), data }.
// NOTE(review): a data pointer of (const char*)-1 looks like a placeholder that
// RegisterServerEx replaces with the DLL path it is given - confirm in registry.h.
// The table ends with an all-zero row.
const REG_DATA g_regData[] = {
	// Proxy/stub class object (must match CLSID_InsideDCOMStdProxy)
	{ "CLSID\\{10000006-0000-0000-0000-000000000001}", 0, "PSSum" },
	{ "CLSID\\{10000006-0000-0000-0000-000000000001}\\InprocServer32", 0, (const char*)-1 },
	{ "CLSID\\{10000006-0000-0000-0000-000000000001}\\InprocServer32", "ThreadingModel", "Both" },

	// ISum interface, routed to the proxy/stub class above.
	// NumMethods = 3 IUnknown methods + Sum (vtable slot 3).
	{ "Interface\\{10000001-0000-0000-0000-000000000001}", 0, "ISum" },
	{ "Interface\\{10000001-0000-0000-0000-000000000001}\\ProxyStubClsid32", 0, "{10000006-0000-0000-0000-000000000001}" },
	{ "Interface\\{10000001-0000-0000-0000-000000000001}\\NumMethods", 0, "4" },

	// terminator
	{ 0, 0, 0 }
};

// CLSID of this proxy/stub class object; must agree with the CLSID keys in g_regData
// and is the only class DllGetClassObject will hand out.
const CLSID CLSID_InsideDCOMStdProxy =
	{ 0x10000006, 0x0000, 0x0000, { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01 } };

// Live-object counter maintained by the object constructors/destructors in this
// file; DllCanUnloadNow answers S_OK only when it is zero.
long g_cObjects = 0;

// Module handle stored by DllMain and used by DllRegisterServer to find this DLL's path.
HINSTANCE g_hInstance;

class CRpcStubBuffer : public IRpcStubBuffer
{
public:
	// IUnknown
	ULONG __stdcall AddRef();
	ULONG __stdcall Release();
	HRESULT __stdcall QueryInterface(REFIID iid, void **ppv);

	// IRpcStubBuffer
	HRESULT __stdcall Connect(IUnknown* pUnknown);
	HRESULT __stdcall Invoke(RPCOLEMESSAGE* pMessage, IRpcChannelBuffer* pRpcChannel);
	IRpcStubBuffer* __stdcall IsIIDSupported(REFIID riid);
	ULONG __stdcall CountRefs();
	HRESULT __stdcall DebugServerQueryInterface(void**);
	void __stdcall DebugServerRelease(void*);
	void __stdcall Disconnect();

	CRpcStubBuffer(REFIID riid);
	~CRpcStubBuffer() { g_cObjects--; }

private:
	long m_cRef;
	long m_cConnection;
	ISum* m_pObj;
	IID m_iid;
};

// Creates an unconnected stub that will marshal interface riid.
// m_pObj starts out NULL rather than indeterminate, so code can test whether a
// server object is attached. The object count is bumped atomically because the
// DLL is registered ThreadingModel=Both.
CRpcStubBuffer::CRpcStubBuffer(REFIID riid) : m_cRef(0), m_cConnection(0), m_pObj(NULL), m_iid(riid)
{
	InterlockedIncrement(&g_cObjects);
}

// Hands out the stub's single interface, IRpcStubBuffer, which doubles as its
// IUnknown. Every other IID is refused with E_NOINTERFACE and *ppv cleared.
HRESULT CRpcStubBuffer::QueryInterface(REFIID riid, void** ppv)
{
	const bool bSupported = (riid == IID_IUnknown) || (riid == IID_IRpcStubBuffer);
	if(!bSupported)
	{
		*ppv = NULL;
		return E_NOINTERFACE;
	}

	cout << "Stub: IRpcStubBuffer::QueryInterface() for IRpcStubBuffer" << endl;
	*ppv = static_cast<IRpcStubBuffer*>(this);
	AddRef();
	return S_OK;
}

// Adds a reference to the stub and returns the new count. The increment is
// interlocked because the DLL is registered ThreadingModel=Both, so AddRef and
// Release may race on different threads.
ULONG CRpcStubBuffer::AddRef()
{
	return InterlockedIncrement(&m_cRef);
}

// Drops a reference and deletes the stub when none remain; returns the new count.
ULONG CRpcStubBuffer::Release()
{
	// Keep the decremented value in a local: re-reading m_cRef afterwards could see
	// another thread's update, or touch memory another thread already freed.
	long cRef = InterlockedDecrement(&m_cRef);
	if(cRef == 0)
		delete this;
	return cRef;
}

// Attaches the stub to a server object by querying it for m_iid; the resulting
// pointer is kept in m_pObj. The connection counter only moves when the query
// succeeds, so CountRefs() never reports a reference the stub did not obtain.
// Returns E_INVALIDARG for a NULL object, otherwise the QueryInterface result.
HRESULT CRpcStubBuffer::Connect(IUnknown* pUnknown)
{
	cout << "Stub: IRpcStubBuffer::Connect " << pUnknown << endl;
	if(pUnknown == NULL)
		return E_INVALIDARG;

	HRESULT hr = pUnknown->QueryInterface(m_iid, (void**)&m_pObj);
	if(SUCCEEDED(hr))
		m_cConnection++;
	return hr;
}

// Unmarshals and dispatches one call from the proxy. pMessage->iMethod is the
// vtable slot the proxy called. For ISum::Sum (slot 3) the request carries two
// ints (x, y) and the reply carries one int (the sum).
// Returns E_UNEXPECTED when not connected or for an unimplemented method,
// RPC_E_INVALID_DATA for a request too short to hold the arguments, or a failing
// HRESULT from the object or the channel.
HRESULT CRpcStubBuffer::Invoke(RPCOLEMESSAGE* pMessage, IRpcChannelBuffer* pRpcChannel)
{
	cout << "Stub: IRpcStubBuffer::Invoke cbBuffer = " << pMessage->cbBuffer << endl;
	if(m_pObj == NULL)
		return E_UNEXPECTED; // no server object attached

	switch(pMessage->iMethod) // What method of ISum is the proxy calling?
	{
	case 3: // The proxy is calling ISum::Sum
		{
			// The request arrives off the wire: never read past what was received.
			if(pMessage->cbBuffer < 2 * sizeof(int))
				return RPC_E_INVALID_DATA;
			int x = ((int*)pMessage->Buffer)[0];
			int y = ((int*)pMessage->Buffer)[1];

			int result = 0;
			HRESULT hr = m_pObj->Sum(x, y, &result);
			if(FAILED(hr))
				return hr; // don't ship an undefined result back to the client

			// The reply is written into the buffer obtained from GetBuffer(); if that
			// fails there is no valid buffer to write into.
			pMessage->cbBuffer = sizeof(int);
			hr = pRpcChannel->GetBuffer(pMessage, m_iid);
			if(FAILED(hr))
				return hr;
			((int*)pMessage->Buffer)[0] = result;
			return NOERROR;
		}
	// case other methods here...
	}
	return E_UNEXPECTED;
}

// Reports whether this stub can marshal riid. The return value is an interface
// pointer, not a flag: it is this stub when riid is the interface it handles, and
// NULL otherwise. Casting true/false to a pointer would hand the caller address 1.
// The returned pointer is AddRef'd, following COM's rule for returned interface
// pointers. NOTE(review): confirm against the IRpcStubBuffer::IsIIDSupported docs.
IRpcStubBuffer* CRpcStubBuffer::IsIIDSupported(REFIID riid)
{
	if(riid != m_iid)
		return NULL;
	AddRef();
	return this;
}

// Reports the stub's connection counter, which Connect() raises and Disconnect() lowers.
ULONG CRpcStubBuffer::CountRefs()
{
	ULONG cRefs = m_cConnection;
	return cRefs;
}

// Debugger hook for retrieving the server object's interface; not supported here.
// Following the COM out-parameter convention, *ppv is still cleared on failure so
// the caller never sees an indeterminate pointer.
HRESULT CRpcStubBuffer::DebugServerQueryInterface(void** ppv)
{
	if(ppv != NULL)
		*ppv = NULL;
	return E_NOTIMPL;
}

// Counterpart to DebugServerQueryInterface. That method never hands out a pointer,
// so there is nothing to release.
void CRpcStubBuffer::DebugServerRelease(void* pv)
{
	(void)pv; // intentionally unused
}

// Releases the stub's pointer to the server object and lowers the connection
// counter. m_pObj is cleared after release, so calling Disconnect again, or after a
// failed Connect (which leaves m_pObj NULL), is a harmless no-op instead of a
// double Release.
void CRpcStubBuffer::Disconnect()
{
	cout << "Stub: IRpcStubBuffer::Disconnect" << endl;
	if(m_pObj == NULL)
		return;
	m_pObj->Release();
	m_pObj = NULL;
	m_cConnection--;
}

// Non-delegating face of CRpcProxyBuffer. Its methods are declared in the same
// vtable order as IRpcProxyBuffer (QueryInterface, AddRef, Release, Connect,
// Disconnect). A pointer to this interface can therefore be handed out as an
// IRpcProxyBuffer* (see CRpcProxyBuffer::QueryInterface_NoAggregation), while its
// first three slots bypass the outer unknown.
// Do NOT reorder these declarations: the vtable layout is load-bearing.
interface INoAggregationRpcProxyBuffer
{
	virtual HRESULT __stdcall QueryInterface_NoAggregation(REFIID riid, void** ppv)=0;	// slot 0 (IUnknown::QueryInterface)
	virtual ULONG __stdcall AddRef_NoAggregation()=0;	// slot 1 (IUnknown::AddRef)
	virtual ULONG __stdcall Release_NoAggregation()=0;	// slot 2 (IUnknown::Release)
	virtual HRESULT __stdcall Connect(IRpcChannelBuffer* pRpcChannel)=0;	// slot 3 (IRpcProxyBuffer::Connect)
	virtual void __stdcall Disconnect(void)=0;	// slot 4 (IRpcProxyBuffer::Disconnect)
};

// Client-side proxy for ISum, created by CPSFactoryBuffer::CreateProxy and
// aggregated inside pUnknownOuter.
// - The ISum (and IRpcProxyBuffer-base) IUnknown methods delegate to the outer unknown.
// - The INoAggregationRpcProxyBuffer face provides the non-delegating IUnknown that
//   owns this object's lifetime (m_cRef).
// Sum() marshals each call over the channel supplied to Connect().
class CRpcProxyBuffer : public IRpcProxyBuffer, public ISum, public INoAggregationRpcProxyBuffer
{
public:
	// IUnknown
	// Delegating implementations: they forward to m_pUnknownOuter.
	ULONG __stdcall AddRef();
	ULONG __stdcall Release();
	HRESULT __stdcall QueryInterface(REFIID riid, void** ppv);

	// INoAggregationRpcProxyBuffer
	// Non-delegating implementations: they manage m_cRef and this object's identity.
	ULONG __stdcall AddRef_NoAggregation();
	ULONG __stdcall Release_NoAggregation();
	HRESULT __stdcall QueryInterface_NoAggregation(REFIID riid, void** ppv);

	// IRpcProxyBuffer
	// Same signatures as INoAggregationRpcProxyBuffer::Connect/Disconnect, so these
	// override the slots of both base interfaces.
	HRESULT __stdcall Connect(IRpcChannelBuffer* pRpcChannel);
	void __stdcall Disconnect(void);

	// ISum
	// Marshals the call to the stub through m_pRpcChannel.
	HRESULT __stdcall Sum(int x, int y, int* retval);

	CRpcProxyBuffer(IUnknown* pUnknownOuter);
	~CRpcProxyBuffer();

private:
	long m_cRef;	// non-delegating reference count; the object deletes itself at zero
	IRpcChannelBuffer* m_pRpcChannel;	// channel from Connect(), AddRef'd; released in Disconnect()
	IUnknown* m_pUnknownOuter;	// controlling unknown; not AddRef'd by this object
};

// Creates an unconnected proxy whose delegating IUnknown calls go to pUnknownOuter.
// m_pRpcChannel starts NULL rather than indeterminate, so an unconnected proxy can
// be recognised. The object count is bumped atomically (ThreadingModel=Both).
CRpcProxyBuffer::CRpcProxyBuffer(IUnknown* pUnknownOuter) : m_cRef(0), m_pRpcChannel(NULL), m_pUnknownOuter(pUnknownOuter)
{
	InterlockedIncrement(&g_cObjects);
}

// Destroys the proxy and removes it from the DLL's live-object count. The
// decrement is atomic because the DLL is registered ThreadingModel=Both, so
// proxies may be destroyed concurrently on different threads.
CRpcProxyBuffer::~CRpcProxyBuffer()
{
	InterlockedDecrement(&g_cObjects);
}

// Delegating QueryInterface: identity belongs to the controlling (outer) unknown,
// so the request is passed through to it unchanged.
HRESULT CRpcProxyBuffer::QueryInterface(REFIID riid, void** ppv)
{
	IUnknown* pOuter = m_pUnknownOuter;
	return pOuter->QueryInterface(riid, ppv);
}

// Delegating AddRef: the reference is counted on the outer unknown, not on m_cRef.
ULONG CRpcProxyBuffer::AddRef()
{
	IUnknown* pOuter = m_pUnknownOuter;
	return pOuter->AddRef();
}

// Delegating Release: the reference is dropped on the outer unknown, not on m_cRef.
ULONG CRpcProxyBuffer::Release()
{
	IUnknown* pOuter = m_pUnknownOuter;
	return pOuter->Release();
}

// Non-delegating QueryInterface: the inner object's own identity, used by the
// outer unknown and by CreateProxy.
// - IID_IUnknown / IID_IRpcProxyBuffer return the INoAggregationRpcProxyBuffer face,
//   whose vtable matches IRpcProxyBuffer but whose IUnknown slots are non-delegating.
// - IID_ISum returns the ISum face (its IUnknown slots delegate to the outer unknown).
// Any other IID clears *ppv and returns E_NOINTERFACE.
HRESULT CRpcProxyBuffer::QueryInterface_NoAggregation(REFIID riid, void** ppv)
{
	if(riid == IID_IUnknown || riid == IID_IRpcProxyBuffer)
	{
		cout << "Proxy: IRpcProxyBuffer::QueryInterface() for IRpcProxyBuffer" << endl;
		*ppv = (INoAggregationRpcProxyBuffer*)this;
	}
	else if(riid == IID_ISum)
	{
		cout << "Proxy: IRpcProxyBuffer::QueryInterface() for ISum" << endl;
		*ppv = (ISum*)this;
	}
	else
	{
		*ppv = NULL;
		return E_NOINTERFACE;
	}
	// AddRef through the pointer being returned (vtable slot 1 of that face):
	// AddRef_NoAggregation for the IRpcProxyBuffer face, the delegating AddRef
	// (outer unknown) for the ISum face.
	((IUnknown*)(*ppv))->AddRef();
	return S_OK;
}

// Non-delegating AddRef on the proxy's own count (m_cRef); returns the new count.
// Interlocked because ThreadingModel=Both allows concurrent calls.
ULONG CRpcProxyBuffer::AddRef_NoAggregation()
{
	return InterlockedIncrement(&m_cRef);
}

// Non-delegating Release: drops a reference from m_cRef, deletes the proxy when
// none remain, and returns the new count.
ULONG CRpcProxyBuffer::Release_NoAggregation()
{
	// Use the value the atomic decrement produced; re-reading m_cRef could race
	// with another thread or touch an already-deleted object.
	long cRef = InterlockedDecrement(&m_cRef);
	if(cRef == 0)
		delete this;
	return cRef;
}

// Attaches the proxy to the channel that carries its calls to the stub. The proxy
// holds its own reference on the channel until Disconnect(). Always returns S_OK.
HRESULT CRpcProxyBuffer::Connect(IRpcChannelBuffer* pRpcChannel)
{
	cout << "Proxy: IRpcProxyBuffer::Connect " << pRpcChannel << endl;
	pRpcChannel->AddRef();
	m_pRpcChannel = pRpcChannel;
	return S_OK;
}

// Releases the proxy's reference on its channel. The pointer is cleared after
// release, so a repeated Disconnect() is a no-op instead of a double Release
// through a NULL or dangling pointer.
void CRpcProxyBuffer::Disconnect()
{
	if(m_pRpcChannel != NULL)
	{
		m_pRpcChannel->Release();
		m_pRpcChannel = NULL;
	}
	cout << "Proxy: IRpcProxyBuffer::Disconnect" << endl;
}

// Client-side marshaling for ISum::Sum: packs x and y into a channel buffer,
// sends them to the stub as vtable slot 3, and unpacks the single int reply into
// *retval. Returns:
// - E_POINTER for a NULL retval;
// - CO_E_OBJNOTCONNECTED when the proxy has no channel (e.g. after Disconnect());
// - RPC_E_INVALID_DATA for a reply too short to hold the result;
// - otherwise the channel's HRESULT.
HRESULT CRpcProxyBuffer::Sum(int x, int y, int* retval)
{
	cout << "IRpcProxyBuffer::Sum" << endl;
	if(retval == NULL)
		return E_POINTER;
	if(m_pRpcChannel == NULL)
		return CO_E_OBJNOTCONNECTED;

	RPCOLEMESSAGE Message = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
	ULONG status = 0;
	Message.cbBuffer = sizeof(int)*2;
	Message.iMethod = 3; // VTBL entry for ISum::Sum
	HRESULT hr = m_pRpcChannel->GetBuffer(&Message, IID_ISum);
	if(FAILED(hr))
		return hr; // assume no usable buffer was produced: nothing to fill or free
	((int*)Message.Buffer)[0] = x;
	((int*)Message.Buffer)[1] = y;

	hr = m_pRpcChannel->SendReceive(&Message, &status);
	if(SUCCEEDED(hr))
	{
		// Only a successful call carries a reply, and it must be big enough to read.
		if(Message.cbBuffer < sizeof(int))
		{
			hr = RPC_E_INVALID_DATA;
		}
		else
		{
			cout << "Proxy: IRpcProxyBuffer::Sum Message = " << ((int*)Message.Buffer)[0] << " status = " << status << endl;
			*retval = ((int*)Message.Buffer)[0];
		}
	}
	// NOTE(review): FreeBuffer is called after SendReceive on every path; confirm in
	// the IRpcChannelBuffer docs whether the channel already frees the buffer when
	// SendReceive fails.
	m_pRpcChannel->FreeBuffer(&Message);
	return hr;
}

class CPSFactoryBuffer : public IPSFactoryBuffer
{
public:
	// IUnknown
	ULONG __stdcall AddRef();
	ULONG __stdcall Release();
	HRESULT __stdcall QueryInterface(REFIID riid, void** ppv);

	// IPSFactoryBuffer
    HRESULT __stdcall CreateProxy(IUnknown* pUnknownOuter, REFIID riid, IRpcProxyBuffer** ppProxy, void** ppv);
	HRESULT __stdcall CreateStub(REFIID riid, IUnknown* pUnkServer, IRpcStubBuffer** ppStub);

	CPSFactoryBuffer() : m_cRef(0) { };
	~CPSFactoryBuffer() { };

private:
	long m_cRef;
};

// Creates the client-side proxy for riid, aggregated inside pUnknownOuter.
// On success:
// - *ppv receives the requested interface. For IID_ISum its reference is counted on
//   the outer unknown, because ISum's AddRef delegates.
// - *ppProxy receives the proxy's non-delegating IRpcProxyBuffer, which owns the
//   proxy object.
// On failure both out-parameters are NULL and the proxy is destroyed.
HRESULT CPSFactoryBuffer::CreateProxy(IUnknown* pUnknownOuter, REFIID riid, IRpcProxyBuffer** ppProxy, void** ppv)
{
	if(ppProxy == NULL || ppv == NULL)
		return E_POINTER;
	*ppProxy = NULL;
	*ppv = NULL;

	CRpcProxyBuffer* pRpcProxyBuffer = new CRpcProxyBuffer(pUnknownOuter);
	if(pRpcProxyBuffer == NULL)
		return E_OUTOFMEMORY;
	cout << "Proxy: CFactory::CreateProxy() " << pRpcProxyBuffer << endl;

	HRESULT hr = pRpcProxyBuffer->QueryInterface_NoAggregation(riid, ppv); // For IID_ISum
	if(FAILED(hr))
	{
		// A failed QueryInterface takes no reference, so nothing else owns the proxy.
		delete pRpcProxyBuffer;
		return hr;
	}
	return pRpcProxyBuffer->QueryInterface_NoAggregation(IID_IRpcProxyBuffer, (void**)ppProxy);
}

// Creates a server-side stub for riid and connects it to pUnkServer.
// Returns the stub through *ppStub.
// - If the server object does not support riid, Connect's failure is returned and
//   the stub destroyed, instead of handing COM a stub with no server object.
// - Returns E_INVALIDARG for a NULL server object.
HRESULT CPSFactoryBuffer::CreateStub(REFIID riid, IUnknown* pUnkServer, IRpcStubBuffer** ppStub)
{
	cout << "Stub: CFactory::CreateStub() " << pUnkServer << endl;
	if(ppStub == NULL)
		return E_POINTER;
	*ppStub = NULL;
	if(pUnkServer == NULL)
		return E_INVALIDARG;

	CRpcStubBuffer *pRpcStubBuffer = new CRpcStubBuffer(riid);
	if(pRpcStubBuffer == NULL)
		return E_OUTOFMEMORY;

	HRESULT hr = pRpcStubBuffer->Connect(pUnkServer);
	if(FAILED(hr))
	{
		// No reference has been handed out yet, so the stub can be deleted directly.
		delete pRpcStubBuffer;
		return hr;
	}
	return pRpcStubBuffer->QueryInterface(IID_IRpcStubBuffer, (void**)ppStub);
}

// Adds a reference to the factory and returns the new count. Interlocked because
// the DLL is registered ThreadingModel=Both.
ULONG CPSFactoryBuffer::AddRef()
{
	return InterlockedIncrement(&m_cRef);
}

// Drops a reference, deleting the factory when none remain; returns the new count.
ULONG CPSFactoryBuffer::Release()
{
	// Return the value produced by the atomic decrement; re-reading m_cRef could
	// race with another thread or read freed memory.
	long cRef = InterlockedDecrement(&m_cRef);
	if(cRef == 0)
		delete this;
	return cRef;
}

// Exposes IPSFactoryBuffer, which also serves as the factory's IUnknown. Any other
// IID clears *ppv and yields E_NOINTERFACE.
HRESULT CPSFactoryBuffer::QueryInterface(REFIID riid, void** ppv)
{
	if(riid != IID_IUnknown && riid != IID_IPSFactoryBuffer)
	{
		*ppv = NULL;
		return E_NOINTERFACE;
	}

	cout << "Proxy/Stub: CFactory::QueryInteface() for IPSFactoryBuffer" << endl;
	*ppv = static_cast<IPSFactoryBuffer*>(this);
	AddRef();
	return S_OK;
}

// Tells COM whether this DLL may be unloaded: S_OK when g_cObjects is zero,
// S_FALSE otherwise.
HRESULT __stdcall DllCanUnloadNow()
{
	const bool bIdle = (g_cObjects == 0);
	cout << "Proxy/Stub: DllCanUnloadNow() " << (bIdle ? "Yes" : "No") << endl;
	return bIdle ? S_OK : S_FALSE;
}

// COM entry point: returns the proxy/stub factory for CLSID_InsideDCOMStdProxy
// through *ppv, queried for riid.
// Returns:
// - CLASS_E_CLASSNOTAVAILABLE for any other CLSID;
// - E_POINTER for a NULL ppv;
// - E_OUTOFMEMORY if allocation fails;
// - otherwise the QueryInterface result.
// *ppv is NULL on every failure path.
HRESULT __stdcall DllGetClassObject(REFCLSID clsid, REFIID riid, void** ppv)
{
	cout << "Proxy/Stub: DllGetClassObject" << endl;
	if(ppv == NULL)
		return E_POINTER;
	*ppv = NULL;

	if(clsid != CLSID_InsideDCOMStdProxy)
		return CLASS_E_CLASSNOTAVAILABLE;

	CPSFactoryBuffer* pPSFactoryBuffer = new CPSFactoryBuffer;
	if(pPSFactoryBuffer == NULL)
		return E_OUTOFMEMORY;

	// QueryInterface for IPSFactoryBuffer
	HRESULT hr = pPSFactoryBuffer->QueryInterface(riid, ppv);
	if(FAILED(hr))
		delete pPSFactoryBuffer; // no reference was taken, so it would otherwise leak
	return hr;
}

// Writes the g_regData entries to the registry, pointing InprocServer32 at this
// DLL's full path. Fails, instead of registering a garbage or truncated path, when
// the module file name cannot be retrieved intact.
HRESULT __stdcall DllRegisterServer()
{
	char DllPath[MAX_PATH];
	// Explicit ANSI variant: DllPath is a char buffer regardless of UNICODE settings.
	DWORD cch = GetModuleFileNameA(g_hInstance, DllPath, sizeof(DllPath));
	if(cch == 0)
	{
		DWORD err = GetLastError();
		return err != 0 ? HRESULT_FROM_WIN32(err) : E_FAIL;
	}
	if(cch >= sizeof(DllPath))
		return HRESULT_FROM_WIN32(ERROR_INSUFFICIENT_BUFFER); // path was truncated

	return RegisterServerEx(g_regData, DllPath);
}

// Undoes DllRegisterServer by handing the same g_regData table to
// UnregisterServerEx (declared in registry.h) and returning its result.
HRESULT __stdcall DllUnregisterServer()
{
	HRESULT hr = UnregisterServerEx(g_regData);
	return hr;
}

// DLL entry point. Stores the module handle for DllRegisterServer, traces the
// notification reason, and always reports success. pv is not used.
BOOL WINAPI DllMain(HINSTANCE hInstance, DWORD dwReason, void* pv)
{
	g_hInstance = hInstance;
	cout << "Proxy/Stub: DllMain() Reason is " << dwReason << endl;
	return TRUE;
}